{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tf2onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 261 files belonging to 3 classes.\n",
      "Using 209 files for training.\n",
      "Found 261 files belonging to 3 classes.\n",
      "Using 52 files for validation.\n"
     ]
    }
   ],
   "source": [
    "data_dir = 'data/train/'\n",
    "\n",
    "# Options shared by both splits. validation_split and seed have to be the same\n",
    "# in the two calls so the training and validation subsets do not overlap.\n",
    "split_kwargs = dict(\n",
    "    color_mode='grayscale',\n",
    "    validation_split=0.2,\n",
    "    seed=123,\n",
    "    image_size=(28, 28),\n",
    "    batch_size=10,\n",
    ")\n",
    "\n",
    "train_ds = tf.keras.utils.image_dataset_from_directory(\n",
    "    data_dir, subset=\"training\", **split_kwargs)\n",
    "valid_ds = tf.keras.utils.image_dataset_from_directory(\n",
    "    data_dir, subset=\"validation\", **split_kwargs)\n",
    "\n",
    "# Multiply pixel values by 1/255; labels are passed through unchanged.\n",
    "normalization_layer = tf.keras.layers.Rescaling(1./255)\n",
    "\n",
    "\n",
    "def _rescale(images, labels):\n",
    "    \"\"\"Apply normalization_layer to a batch of images, keeping its labels.\"\"\"\n",
    "    return normalization_layer(images), labels\n",
    "\n",
    "\n",
    "normalized_train_ds = train_ds.map(_rescale)\n",
    "normalized_valid_ds = valid_ds.map(_rescale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['0', '1', '10']\n"
     ]
    }
   ],
   "source": [
    "# Class names as reported by the training dataset. The printed order is\n",
    "# lexicographic (['0', '1', '10']), so label index 2 corresponds to class '10'.\n",
    "class_names = train_ds.class_names\n",
    "print(class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first few images (currently disabled - uncomment the lines below to preview a training batch)\n",
    "# plt.figure(figsize=(10, 10))\n",
    "# for images, labels in train_ds.take(1):\n",
    "#   for i in range(2):\n",
    "#     ax = plt.subplot(3, 3, i + 1)\n",
    "#     plt.imshow(images[i].numpy().astype(\"uint8\"))\n",
    "#     plt.title(class_names[labels[i]])\n",
    "#     plt.axis(\"off\")\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The datasets are loaded with color_mode='grayscale', i.e. 28x28 images with a\n",
    "# single channel axis. Declare that axis explicitly so the model's input shape\n",
    "# matches both the training data and the (None, 28, 28, 1) ONNX export signature,\n",
    "# instead of relying on Keras implicitly squeezing the trailing size-1 dimension.\n",
    "model = tf.keras.models.Sequential([\n",
    "  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
    "  tf.keras.layers.Dense(128, activation='relu'),\n",
    "  tf.keras.layers.Dropout(0.2),\n",
    "  # No activation here: the model outputs one raw logit per class.\n",
    "  tf.keras.layers.Dense(len(class_names))\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_18\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " flatten_18 (Flatten)        (None, 784)               0         \n",
      "                                                                 \n",
      " dense_36 (Dense)            (None, 128)               100480    \n",
      "                                                                 \n",
      " dropout_18 (Dropout)        (None, 128)               0         \n",
      "                                                                 \n",
      " dense_37 (Dense)            (None, 3)                 387       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 100,867\n",
      "Trainable params: 100,867\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print each layer's output shape and parameter count.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The final Dense layer has no activation, so the loss is told to expect logits.\n",
    "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss=loss_fn,\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "21/21 [==============================] - 0s 7ms/step - loss: 0.3889 - accuracy: 0.8565 - val_loss: 0.1763 - val_accuracy: 0.9423\n",
      "Epoch 2/5\n",
      "21/21 [==============================] - 0s 4ms/step - loss: 0.1107 - accuracy: 0.9569 - val_loss: 0.1409 - val_accuracy: 0.9615\n",
      "Epoch 3/5\n",
      "21/21 [==============================] - 0s 4ms/step - loss: 0.0420 - accuracy: 0.9952 - val_loss: 0.1148 - val_accuracy: 0.9615\n",
      "Epoch 4/5\n",
      "21/21 [==============================] - 0s 4ms/step - loss: 0.0251 - accuracy: 0.9952 - val_loss: 0.1165 - val_accuracy: 0.9615\n",
      "Epoch 5/5\n",
      "21/21 [==============================] - 0s 4ms/step - loss: 0.0156 - accuracy: 1.0000 - val_loss: 0.1182 - val_accuracy: 0.9615\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x25759086810>"
      ]
     },
     "execution_count": 181,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for 5 epochs on the rescaled training set, validating after each epoch.\n",
    "model.fit(normalized_train_ds, validation_data=normalized_valid_ds, epochs=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input signature of the exported graph: a dynamic batch of 28x28 single-channel\n",
    "# float32 images. The 1/255 rescaling is applied in the tf.data pipeline, not\n",
    "# inside the model, so consumers of the ONNX file must feed already-rescaled inputs.\n",
    "spec = (tf.TensorSpec((None, 28, 28, 1), tf.float32, name=\"input\"),)\n",
    "output_path = \"onnx/test010.onnx\"\n",
    "\n",
    "# from_keras returns (model_proto, external_tensor_storage); the second item is\n",
    "# not needed here. Passing output_path also writes the model to disk.\n",
    "model_proto, _ = tf2onnx.convert.from_keras(\n",
    "    model, input_signature=spec, opset=13, output_path=output_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
